#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
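Example Airflow DAG for the Google Cloud ML Engine (AI Platform) service, covering training,
model and version management, batch prediction and evaluation.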
"""
import os
from typing import Dict

from airflow import models
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.mlengine import (
    MLEngineCreateModelOperator,
    MLEngineCreateVersionOperator,
    MLEngineDeleteModelOperator,
    MLEngineDeleteVersionOperator,
    MLEngineGetModelOperator,
    MLEngineListVersionsOperator,
    MLEngineSetDefaultVersionOperator,
    MLEngineStartBatchPredictionJobOperator,
    MLEngineStartTrainingJobOperator,
)
from airflow.providers.google.cloud.utils import mlengine_operator_utils
from airflow.utils.dates import days_ago

# Google Cloud project and ML Engine model targeted by this example.
PROJECT_ID = os.getenv("GCP_PROJECT_ID", "example-project")
MODEL_NAME = os.getenv("GCP_MLENGINE_MODEL_NAME", "model_name")

# Cloud Storage locations. The "INVALID BUCKET NAME" defaults are placeholders:
# point the corresponding environment variables at real paths before running.
SAVED_MODEL_PATH = os.getenv(
    "GCP_MLENGINE_SAVED_MODEL_PATH",
    "gs://INVALID BUCKET NAME/saved-model/",
)
JOB_DIR = os.getenv(
    "GCP_MLENGINE_JOB_DIR",
    "gs://INVALID BUCKET NAME/keras-job-dir",
)
PREDICTION_INPUT = os.getenv(
    "GCP_MLENGINE_PREDICTION_INPUT",
    "gs://INVALID BUCKET NAME/prediction_input.json",
)
PREDICTION_OUTPUT = os.getenv(
    "GCP_MLENGINE_PREDICTION_OUTPUT",
    "gs://INVALID BUCKET NAME/prediction_output",
)

# Training package URI and the Python module inside it used as the trainer.
TRAINER_URI = os.getenv(
    "GCP_MLENGINE_TRAINER_URI",
    "gs://INVALID BUCKET NAME/trainer.tar.gz",
)
# NOTE: the doubled "TRAINER_TRAINER" in the variable name is kept as-is so
# existing environment configurations keep working.
TRAINER_PY_MODULE = os.getenv("GCP_MLENGINE_TRAINER_TRAINER_PY_MODULE", "trainer.task")

# Dataflow temp/staging locations handed to the evaluation summary step.
SUMMARY_TMP = os.getenv("GCP_MLENGINE_DATAFLOW_TMP", "gs://INVALID BUCKET NAME/tmp/")
SUMMARY_STAGING = os.getenv("GCP_MLENGINE_DATAFLOW_STAGING", "gs://INVALID BUCKET NAME/staging/")

# Template parameters: ``params.model_name`` is used by the templated job IDs.
default_args = {"params": {"model_name": MODEL_NAME}}

# End-to-end ML Engine walkthrough: train, create a model and two versions,
# promote one version to default, run batch prediction and an evaluation
# pipeline, then delete the version and the model.
with models.DAG(
    "example_gcp_mlengine",
    schedule_interval=None,  # Override to match your needs
    start_date=days_ago(1),
    tags=['example'],
    # Supplies ``params.model_name``, which the templated job IDs below use;
    # without it that parameter is undefined when the templates are rendered.
    default_args=default_args,
) as dag:
    # [START howto_operator_gcp_mlengine_training]
    training = MLEngineStartTrainingJobOperator(
        task_id="training",
        project_id=PROJECT_ID,
        region="us-central1",
        job_id="training-job-{{ ts_nodash }}-{{ params.model_name }}",
        runtime_version="1.15",
        python_version="3.7",
        job_dir=JOB_DIR,
        package_uris=[TRAINER_URI],
        training_python_module=TRAINER_PY_MODULE,
        training_args=[],
        labels={"job_type": "training"},
    )
    # [END howto_operator_gcp_mlengine_training]

    # [START howto_operator_gcp_mlengine_create_model]
    create_model = MLEngineCreateModelOperator(
        task_id="create-model",
        project_id=PROJECT_ID,
        model={
            "name": MODEL_NAME,
        },
    )
    # [END howto_operator_gcp_mlengine_create_model]

    # [START howto_operator_gcp_mlengine_get_model]
    get_model = MLEngineGetModelOperator(
        task_id="get-model",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_get_model]

    # [START howto_operator_gcp_mlengine_print_model]
    get_model_result = BashOperator(
        bash_command="echo \"{{ task_instance.xcom_pull('get-model') }}\"",
        task_id="get-model-result",
    )
    # [END howto_operator_gcp_mlengine_print_model]

    # [START howto_operator_gcp_mlengine_create_version1]
    create_version = MLEngineCreateVersionOperator(
        task_id="create-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v1",
            "description": "First-version",
            "deployment_uri": f'{JOB_DIR}/keras_export/',
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version1]

    # [START howto_operator_gcp_mlengine_create_version2]
    create_version_2 = MLEngineCreateVersionOperator(
        task_id="create-version-2",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version={
            "name": "v2",
            "description": "Second version",
            "deployment_uri": SAVED_MODEL_PATH,
            "runtime_version": "1.15",
            "machineType": "mls1-c1-m2",
            "framework": "TENSORFLOW",
            "pythonVersion": "3.7",
        },
    )
    # [END howto_operator_gcp_mlengine_create_version2]

    # [START howto_operator_gcp_mlengine_default_version]
    set_defaults_version = MLEngineSetDefaultVersionOperator(
        task_id="set-default-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version_name="v2",
    )
    # [END howto_operator_gcp_mlengine_default_version]

    # [START howto_operator_gcp_mlengine_list_versions]
    list_version = MLEngineListVersionsOperator(
        task_id="list-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_list_versions]

    # [START howto_operator_gcp_mlengine_print_versions]
    list_version_result = BashOperator(
        bash_command="echo \"{{ task_instance.xcom_pull('list-version') }}\"",
        task_id="list-version-result",
    )
    # [END howto_operator_gcp_mlengine_print_versions]

    # [START howto_operator_gcp_mlengine_get_prediction]
    prediction = MLEngineStartBatchPredictionJobOperator(
        task_id="prediction",
        project_id=PROJECT_ID,
        job_id="prediction-{{ ts_nodash }}-{{ params.model_name }}",
        region="us-central1",
        model_name=MODEL_NAME,
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        output_path=PREDICTION_OUTPUT,
        labels={"job_type": "prediction"},
    )
    # [END howto_operator_gcp_mlengine_get_prediction]

    # [START howto_operator_gcp_mlengine_delete_version]
    delete_version = MLEngineDeleteVersionOperator(
        task_id="delete-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v1"
    )
    # [END howto_operator_gcp_mlengine_delete_version]

    # [START howto_operator_gcp_mlengine_delete_model]
    delete_model = MLEngineDeleteModelOperator(
        task_id="delete-model", project_id=PROJECT_ID, model_name=MODEL_NAME, delete_contents=True
    )
    # [END howto_operator_gcp_mlengine_delete_model]

    # Version creation waits for both training and model creation; prediction
    # needs both versions; cleanup deletes version "v1" before the model itself.
    training >> create_version
    training >> create_version_2
    create_model >> get_model >> [get_model_result, delete_model]
    create_model >> create_version >> create_version_2 >> set_defaults_version >> list_version
    create_version >> prediction
    create_version_2 >> prediction
    prediction >> delete_version
    list_version >> list_version_result
    list_version >> delete_version
    delete_version >> delete_model

    # [START howto_operator_gcp_mlengine_get_metric]
    def get_metric_fn_and_keys():
        """
        Gets metric function and keys used to generate summary

        :return: a ``(metric_fn, keys)`` pair. ``metric_fn`` maps one prediction
            instance to a tuple of metric values, and ``keys`` names those
            values in the same order.
        """

        def normalize_value(inst: Dict):
            # 'dense_4' is presumably the output layer of the trained model — confirm.
            val = float(inst['dense_4'][0])
            return (val,)  # one element per entry in the keys list

        return normalize_value, ['val']  # key order must match.

    # [END howto_operator_gcp_mlengine_get_metric]

    # [START howto_operator_gcp_mlengine_validate_error]
    def validate_err_and_count(summary: Dict) -> Dict:
        """
        Validate summary result

        :param summary: evaluation summary; this check reads its ``val`` and
            ``count`` entries.
        :return: ``summary`` unchanged when every check passes.
        :raises ValueError: if ``val`` is outside ``[0, 1]`` or ``count`` is not 20.
        """
        if summary['val'] > 1:
            raise ValueError(f'Too high val>1; summary={summary}')
        if summary['val'] < 0:
            raise ValueError(f'Too low val<0; summary={summary}')
        # Expected instance count; presumably the size of PREDICTION_INPUT — adjust to your data.
        if summary['count'] != 20:
            raise ValueError(f'Invalid value count != 20; summary={summary}')
        return summary

    # [END howto_operator_gcp_mlengine_validate_error]

    # [START howto_operator_gcp_mlengine_evaluate]
    evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(
        task_prefix="evaluate-ops",
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        prediction_path=PREDICTION_OUTPUT,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
        project_id=PROJECT_ID,
        region="us-central1",
        dataflow_options={
            'project': PROJECT_ID,
            'tempLocation': SUMMARY_TMP,
            'stagingLocation': SUMMARY_STAGING,
        },
        model_name=MODEL_NAME,
        version_name="v1",
        py_interpreter="python3",
    )
    # [END howto_operator_gcp_mlengine_evaluate]

    # Evaluation targets version "v1", so it waits for that version and must
    # finish before "v1" is deleted. (create_model >> create_version is already
    # declared above; re-declaring an existing edge only triggers Airflow's
    # "already registered" warning.)
    create_version >> evaluate_prediction
    evaluate_validation >> delete_version